{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {func}`~tvm.relay.data_dep_optimization.simplify_fc_transpose`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "\n",
    "\n",
    "import tvm\n",
    "from tvm.ir import IRModule\n",
    "from tvm import relay\n",
    "from tvm.relay.data_dep_optimization import simplify_fc_transpose\n",
    "\n",
    "def run_func(func, params, x):\n",
    "    with tvm.transform.PassContext(opt_level=3):\n",
    "        lib = relay.build(func, \"llvm\", params=params)\n",
    "\n",
    "    from tvm.contrib import graph_executor\n",
    "\n",
    "    dev = tvm.cpu(0)\n",
    "    dtype = \"float32\"\n",
    "    m = graph_executor.GraphModule(lib[\"default\"](dev))\n",
    "    # set inputs\n",
    "    m.set_input(\"data\", tvm.nd.array(x.astype(dtype)))\n",
    "    # execute\n",
    "    m.run()\n",
    "    # get outputs\n",
    "    tvm_output = m.get_output(0)\n",
    "    return tvm_output.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 32), float32], %w1: Tensor[(32, 64), float32], %w2: Tensor[(64, 16), float32]) {\n",
      "  %0 = nn.relu(%data);\n",
      "  %1 = transpose(%w1, axes=[1, 0]);\n",
      "  %2 = nn.dense(%0, %1, units=None);\n",
      "  %3 = nn.relu(%2);\n",
      "  %4 = transpose(%w2, axes=[1, 0]);\n",
      "  nn.dense(%3, %4, units=None)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "data = relay.var(\"data\", shape=(1, 32), dtype=\"float32\")\n",
    "x = relay.nn.relu(data)\n",
    "w1 = relay.var(\"w1\", shape=(32, 64), dtype=\"float32\")\n",
    "y = relay.nn.dense(x, relay.transpose(w1, axes=[1, 0]))\n",
    "z = relay.nn.relu(y)\n",
    "w2 = relay.var(\"w2\", shape=(64, 16), dtype=\"float32\")\n",
    "zz = relay.nn.dense(z, relay.transpose(w2, axes=[1, 0]))\n",
    "func = relay.Function(relay.analysis.free_vars(zz), zz)\n",
    "print(func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n"
     ]
    }
   ],
   "source": [
    "params = {\n",
    "    \"w1\": tvm.nd.array(np.random.uniform(-1, 1, (32, 64)).astype(\"float32\")),\n",
    "    \"w2\": tvm.nd.array(np.random.uniform(-1, 1, (64, 16)).astype(\"float32\")),\n",
    "}\n",
    "x_np = np.random.randn(1, 32).astype(\"float32\")\n",
    "old_result = run_func(func, params, x_np)\n",
    "\n",
    "new_func, new_params = simplify_fc_transpose.convert(func, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, %w1.T: Tensor[(64, 32), float32] /* ty=Tensor[(64, 32), float32] */, %w2.T: Tensor[(16, 64), float32] /* ty=Tensor[(16, 64), float32] */) -> Tensor[(1, 16), float32] {\n",
      "  %0 = nn.relu(%data) /* ty=Tensor[(1, 32), float32] */;\n",
      "  %1 = nn.dense(%0, %w1.T, units=None) /* ty=Tensor[(1, 64), float32] */;\n",
      "  %2 = nn.relu(%1) /* ty=Tensor[(1, 64), float32] */;\n",
      "  nn.dense(%2, %w2.T, units=None) /* ty=Tensor[(1, 16), float32] */\n",
      "} /* ty=fn (Tensor[(1, 32), float32], Tensor[(64, 32), float32], Tensor[(16, 64), float32]) -> Tensor[(1, 16), float32] */\n"
     ]
    }
   ],
   "source": [
    "print(new_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_result = run_func(new_func, new_params, x_np)\n",
    "np.testing.assert_allclose(old_result, new_result, atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
